def matrix_zero(mat):
    """Zero every row and column of *mat* that contains a zero, in place.

    A cell counts as "zero" when it is falsy (``0``, ``0.0``, ``False``,
    ``None``, ...), matching the original truthiness test. Zeros are
    located in a first pass before any cell is overwritten, so newly
    written zeros never cascade into further rows/columns.

    Rows may have different lengths: a zero column index is only applied
    to rows long enough to contain it.

    Args:
        mat: A list of mutable rows (e.g. ``list[list[int]]``). May be
            empty or ``None``, in which case it is returned unchanged.

    Returns:
        The same *mat* object, after in-place modification (returned for
        convenience; the mutation is the primary effect).
    """
    if not mat:
        return mat

    # Pass 1: record which rows and columns contain a zero.
    zero_rows = set()
    zero_cols = set()
    for i, row in enumerate(mat):
        for j, value in enumerate(row):
            if not value:
                zero_rows.add(i)
                zero_cols.add(j)

    # Pass 2: clear whole zero rows; otherwise clear only the zero columns
    # that exist in this (possibly shorter) row.
    for i, row in enumerate(mat):
        if i in zero_rows:
            for j in range(len(row)):
                row[j] = 0
        else:
            width = len(row)
            for j in zero_cols:
                if j < width:
                    row[j] = 0

    return mat

if __name__ == "__main__":
    # Demo: zeros at (1, 1) and (2, 2) clear rows 1-2 and columns 1-2.
    mat = [[1, 1, 1], [1, 0, 1], [1, 1, 0]]
    matrix_zero(mat)
    print(mat)  # expected: [[1, 0, 0], [0, 0, 0], [0, 0, 0]]
